CREATE OR REPLACE FUNCTION "pipeline"."prefix_span"("source_table" varchar, "out_table" varchar, "feature_cols" varchar, "min_frequency" int4, "min_support" float8)
  RETURNS "pg_catalog"."text" AS $BODY$
  import numpy as np
  import json
  import time
  import math
  import itertools
  import collections
  import re
  import operator
  import ast

  # Initialise the return value: the envelope that is serialised with
  # json.dumps and returned to the SQL caller.
  result = {
    # 0 on the normal path; throwError() overwrites it with 500.
    "status": 0,
    # Human-readable outcome; replaced by the error text on failure.
    "error_msg": "success",
    "result": {
      # Echo of the arguments this invocation was called with.
      "input_params": {
        "out_table": out_table,
        "source_table": source_table,
        "feature_cols": feature_cols,
        "min_frequency": min_frequency,
        "min_support": min_support
      },
      "output_params": [{
        # Column names of the output table (see the CREATE TABLE in saveOutTable).
        "output_cols": ["_record_id_", "pattern", "frequency", "support"],
        # Filled in by begin_alg once the output table has been written.
        "out_table_name": None
      }]
    }
  }

  class prefixspan_one_item_one_event:
    """PrefixSpan miner for sequences in which every event is a single item.

    After exect() has run, found_patterns maps str(pattern_list) to the
    number of sequences of the projected database that contain the pattern.
    """

    def __init__(self, sdb, min_support):
        # sdb: list of sequences; min_support: absolute (count) threshold.
        self.found_patterns = {}
        self.sdb = sdb
        self.min_support = min_support

    def find_following_pattern(self, item, db):
        """Project db on item: for every sequence containing item, keep
        what follows its first occurrence."""
        return [seq[(seq.index(item) + 1):] for seq in db if item in seq]

    def item_candidate(self, db, min_support):
        """Return the items that occur in at least min_support sequences."""
        occurrences = collections.Counter()
        for seq in db:
            # set() so that an item is counted once per sequence
            occurrences.update(set(seq))
        return [item for item, count in occurrences.items() if count >= min_support]

    def prefixspan(self, patt, db, min_support):
        """Recursively extend the prefix patt with every frequent item of db,
        recording each extension in found_patterns. Always returns None."""
        if len(db) == 0:
            return None

        candidates = self.item_candidate(db, min_support)
        if len(candidates) == 0:
            return None

        for item in candidates:
            projected = self.find_following_pattern(item, db)
            if len(projected) < min_support:
                continue
            extended = patt + [item]
            key = str(extended)
            self.found_patterns[key] = self.found_patterns.get(key, 0) + len(projected)
            # depth-first: grow the new prefix before trying the next item
            self.prefixspan(extended, projected, min_support)

    def exect(self):
        """Mine the whole sequence database starting from the empty prefix."""
        self.prefixspan([], self.sdb, self.min_support)


  class prefixspan_multiple_items_one_event:
    """PrefixSpan miner for sequences whose events are sets of items.

    The item '_' is used as a marker: an event (or candidate) containing
    '_' is treated as the remainder of the event that matched the previous
    pattern element. merge_multiple_events() folds such marked elements back
    into their predecessor.

    NOTE(review): candidates and pattern keys are round-tripped through
    str(set) / ast.literal_eval, which needs the set-literal repr; the
    enclosing function is declared LANGUAGE plpythonu - confirm the runtime
    before relying on this class.
    """

    def __init__(self, sdb, min_support):
        # sdb: list of sequences, each a list of sets; min_support: count threshold.
        self.found_patterns = {}
        self.sdb = sdb
        self.min_support = min_support

    def find_combinations(self, x):
        """Return one stringified singleton set per real item of event x,
        each carrying the '_' marker when x itself carries it."""
        marker = x.intersection({'_'})
        return [str(set({member}).union(marker)) for member in x.difference({'_'})]

    def find_items_from_one_sequence(self, x):
        """Distinct stringified candidates occurring anywhere in sequence x."""
        per_event = map(self.find_combinations, x)
        return list(set(itertools.chain.from_iterable(per_event)))

    def find_items(self, db):
        """Concatenate the per-sequence candidate lists of every sequence in db."""
        per_sequence = map(self.find_items_from_one_sequence, db)
        return list(itertools.chain.from_iterable(per_sequence))

    def item_candidate(self, db, min_support):
        """Return (as sets) the candidates present in at least min_support sequences."""
        tally = collections.Counter(self.find_items(db))
        frequent = []
        for encoded, count in tally.items():
            if count >= min_support:
                # turn the string key back into a set
                frequent.append(ast.literal_eval(encoded))
        return frequent

    def find_following_pattern(self, item, db):
        """Project db on item.

        For each sequence the first matching event is located: an unmarked
        item only matches unmarked events, a marked item matches any event
        that contains it. The projection keeps the unmatched rest of that
        event (tagged with '_') followed by the later events.
        """
        item_is_marked = '_' in item
        projected = []
        for sequence in db:
            hit = None
            for position, event in enumerate(sequence):
                if item.issubset(event) and (item_is_marked or '_' not in event):
                    hit = position
                    break
            if hit is None:
                continue

            leftover = sequence[hit].difference(item)
            tail = sequence[(hit + 1):]
            if len(leftover) > 0:
                projected.append([leftover.union({'_'})] + tail)
            else:
                projected.append(tail)
        return projected

    def prefixspan(self, patt, db, min_support):
        """Recursively extend the prefix patt with every frequent candidate
        of db, recording each extension in found_patterns. Returns None."""
        if len(db) == 0:
            return None

        candidates = self.item_candidate(db, min_support)
        if len(candidates) == 0:
            return None

        for item in candidates:
            projected = self.find_following_pattern(item, db)
            if len(projected) < min_support:
                continue
            extended = patt + [item]
            key = str(extended)
            self.found_patterns[key] = self.found_patterns.get(key, 0) + len(projected)
            self.prefixspan(extended, projected, min_support)

    def merge_multiple_events(self, x):
        """Fold every '_'-marked element of pattern x into the element before
        it and return the resulting list.

        The first merge writes into the list that was passed in; later merges
        work on fresh lists. A list without marked elements is returned as is.
        """
        while True:
            marked = [i for i, event in enumerate(x) if '_' in event]
            if len(marked) == 0:
                return x
            pos = marked[0]
            x[pos - 1] = x[pos - 1].union(x[pos]).difference({'_'})
            x = x[:pos] + x[(pos + 1):]

    def exect(self):
        """Mine self.sdb, then replace found_patterns by a list of
        (merged_pattern, count) tuples."""
        self.found_patterns = {}
        self.prefixspan([], self.sdb, self.min_support)
        merged = []
        for encoded, count in self.found_patterns.items():
            merged.append((self.merge_multiple_events(ast.literal_eval(encoded)), count))
        self.found_patterns = merged

  # Read data
  def get_data_sql(source_table, feature_col):
    """Build the SELECT statement that reads the double-quoted column
    feature_col from source_table."""
    parts = {"column": feature_col, "relation": source_table}
    return 'select "%(column)s" from %(relation)s' % parts

  def getTableMetaInfo(table_name):
    """Return {column_name: data_type} for table_name, read from
    information_schema.columns.

    table_name may be bare ("tbl") or dot-qualified ("schema.tbl"); the
    lookup uses the last dotted component. An unknown table yields {}.
    """
    # The previous `len(tmps) > 0` test was always true, so a bare name hit
    # tmps[1] and raised IndexError; taking the last component covers bare,
    # two-part and three-part names alike.
    bare_name = table_name.split(".")[-1]
    # Bind the name as a parameter instead of pasting it into the SQL text.
    plan = plpy.prepare(
      "select column_name, data_type from information_schema.columns where table_name = $1",
      ["text"])
    # NOTE(review): the schema part is not used as a filter, so same-named
    # tables in different schemas are merged into one dict - confirm whether
    # that is acceptable.
    meta = {}
    for line in plpy.execute(plan, [bare_name]):
      meta[line['column_name']] = line['data_type']
    return meta

  def validate(featureCols, meta):
    """Check that every feature column exists and that all of them are
    either numeric or non-numeric.

    featureCols: list of column names (single/double quotes are stripped
                 before the lookup).
    meta:        {column_name: data_type}, as built by getTableMetaInfo.

    Returns (flag, isNumer, msg):
      (True,  True,  "success")  - all columns have a numeric type
      (True,  False, "success")  - no column has a numeric type
      (False, False, <reason>)   - a column is missing or the types are mixed
    """
    # "bigint" was missing, so bigint columns counted as non-numeric and an
    # integer/bigint mix was rejected as "features type is not same".
    numericTypes = ["smallint", "integer", "bigint", "decimal", "numeric", "real", "double precision", "serial", "bigserial"]
    numberCnt = 0
    for feature in featureCols:
      dataType = meta.get(feature.replace('"', "").replace("'", ""))
      if dataType is None:
        return False, False, "key={} not exists".format(feature)
      if dataType.lower() in numericTypes:
        numberCnt += 1
    if numberCnt == 0:
      return True, False, "success"
    if numberCnt == len(featureCols):
      return True, True, "success"
    return False, False, "features type is not same"

  def createTmpTable(source_table, featureCols, isNumer):
    """Create a helper view exposing one "event" column computed from the
    feature columns of source_table.

    The view selects "pipeline"."cols_to_vec_numeric" (isNumer true) or
    "pipeline"."cols_to_vec_char" (isNumer false), called with the
    comma-joined column names as a string literal.

    Returns (flag, view_name, msg). On failure flag is False and msg holds
    the error text followed by the SQL that was attempted. The caller is
    responsible for dropping the view after a successful call.
    """
    import uuid

    sqlTpl = """
      CREATE VIEW {} AS SELECT "pipeline"."{}"('{}') AS event FROM {}
    """
    # A timestamp with one-second resolution is shared by calls started in
    # the same second, which then contend for one relation name; the random
    # suffix gives every call its own name.
    tmpTableName = "view_fpgrowth_{}_{}".format(int(time.time()), uuid.uuid4().hex[:8])
    functionName = "cols_to_vec_numeric" if isNumer else "cols_to_vec_char"
    # The column list sits inside a single-quoted SQL literal, so embedded
    # single quotes have to be doubled.
    colsLiteral = ",".join(featureCols).replace("'", "''")
    sqlStr = sqlTpl.format(tmpTableName, functionName, colsLiteral, source_table)
    try:
      plpy.execute(sqlStr)
    except Exception as e:
      # Nothing to clean up: the CREATE did not succeed, so any view carrying
      # this name would not be ours to drop.
      return False, tmpTableName, str(e) + " sql=" + sqlStr
    return True, tmpTableName, "success"

  # Create the output table
  def saveOutTable(datas, out_table):
    """Create out_table and fill it with the mined patterns.

    datas:     iterable of [record_id, pattern_literal, frequency, support];
               pattern_literal is pasted between single quotes, so the
               caller must already have doubled any single quote in it.
    out_table: name of the table to create.

    Returns the table name. Exceptions from plpy.execute propagate; when the
    INSERT fails the table created by this call is dropped first.
    """
    value_tpl = "({}, '{}', {}, {})"
    table_name = out_table
    plpy.execute("CREATE TABLE %s (_record_id_ int, pattern text[], frequency int, support float);" %(table_name))

    rows = [value_tpl.format(item[0], item[1], item[2], item[3]) for item in datas]
    if not rows:
      # "INSERT ... VALUES" without any row is a syntax error; an empty
      # input simply yields an empty table.
      return table_name

    sql_str = "INSERT INTO %s (_record_id_, pattern, frequency, support) VALUES" %(table_name)
    sql_str += ",".join(rows)
    try:
      plpy.execute(sql_str)
    except Exception:
      # Do not leave a half-built (empty) table behind: a caller that
      # swallows the error would otherwise commit it and block the next run
      # with "relation already exists".
      plpy.execute("DROP TABLE IF EXISTS %s" %(table_name))
      raise
    return table_name

  def throwError(e):
    """Mark the shared result envelope as failed: status becomes 500 and
    error_msg the string form of e. Nothing is raised or returned."""
    global result
    result.update(status=500, error_msg=str(e))

  # Main routine
  def begin_alg(source_table, out_table, feature_cols, min_frequency, min_support):
    """Run PrefixSpan over source_table and store the patterns in out_table.

    source_table:  a table name, or a complete "create view <name> ..."
                   statement that is executed first and whose view is then
                   used as the source.
    out_table:     name of the table to create for the results.
    feature_cols:  comma-separated column names.
    min_frequency: minimum absolute pattern count.
    min_support:   minimum relative support; the effective threshold is
                   max(min_frequency, ceil(min_support * row_count)).

    Returns the shared result envelope as a JSON string. Failures handled
    here are reported through status 500 / error_msg rather than raised.
    """
    global result
    featureCols = feature_cols.split(",")
    sourceTable = source_table
    if source_table.lower().startswith("create view"):
      try:
        plpy.execute(source_table)
        # "create view <name> ..." - the third whitespace-separated token
        sourceTable = source_table.split()[2]
      except Exception as e:
        # No DROP here: the statement failed, so no view was created by this
        # call. (The former cleanup interpolated the whole CREATE statement
        # into DROP VIEW, which raised again from inside this handler.)
        result["status"] = 500
        result["error_msg"] = "sq={}, errorMsg={}".format(source_table, str(e))
        return json.dumps(result)

    meta = getTableMetaInfo(sourceTable)
    flag, isNumer, msg = validate(featureCols, meta)
    if not flag:
      throwError(msg)
      return json.dumps(result)
    flag, tmpTableName, msg = createTmpTable(sourceTable, featureCols, isNumer)
    if not flag:
      throwError(msg)
      return json.dumps(result)

    feature_col = "event"
    # From here on the helper view exists; the finally clause drops it on
    # every exit path, including unexpected exceptions.
    try:
      try:
        rv = plpy.execute(get_data_sql(tmpTableName, feature_col))
      except Exception as e:
        throwError(e)
        return json.dumps(result)

      datas = [row[feature_col] for row in rv]
      # The relative threshold can only raise the absolute one.
      sup_frequency = math.ceil(min_support * len(datas))
      if min_frequency < sup_frequency:
        min_frequency = sup_frequency

      pre_object = prefixspan_one_item_one_event(datas, min_frequency)
      pre_object.exect()
      # items() instead of iteritems(): same pairs, not tied to Python 2
      frequent_items = sorted(pre_object.found_patterns.items(), key=lambda i: i[1], reverse=True)

      res = []
      for index, (pattern, frequency) in enumerate(frequent_items, 1):
        # "['a', 'b']" -> "{''a'', ''b''}": brackets become array braces and
        # single quotes are doubled for the SQL literal in saveOutTable.
        # NOTE(review): the quotes stay part of each array element - confirm
        # that consumers expect this format.
        literal = str(pattern).replace('[', '{').replace(']', '}').replace('\'', '\'\'')
        res.append([index, literal, frequency, float(frequency) / float(len(datas))])

      if len(res) == 0:
        result["error_msg"] = "No patters found."
        return json.dumps(result)

      try:
        saveRet = saveOutTable(res, out_table)
      except Exception as e:
        throwError(e)
        return json.dumps(result)

      result["result"]["output_params"][0]["out_table_name"] = saveRet
      return json.dumps(result)
    finally:
      plpy.execute("DROP VIEW IF EXISTS {}".format(tmpTableName))

  # Entry point: run the algorithm with the SQL function's arguments and hand
  # its JSON string back as the function's return value.
  # NOTE(review): the `global result` below shares a scope with the
  # `result = {...}` assignment at the top of the body; presumably PL/Python
  # wraps the body in a function, in which case this declaration is what makes
  # `result` visible to the helpers' own `global result` statements - confirm
  # before moving or removing it.
  if __name__ == '__main__':
    global result
    ret = begin_alg(source_table, out_table, feature_cols, min_frequency, min_support)
    return ret
$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100